In [1]:
import plotly.offline as pyo

from plotly.graph_objs import *

import chart_studio.plotly as py

import pandas as pd
from pandas import DataFrame
In [2]:
from plotly import tools
In [3]:
pyo.offline.init_notebook_mode()
In [4]:
iris = pd.read_csv(r"../Data/irisDataset.csv", index_col = 0)
iris.head()
Out[4]:
Sepal length Sepal width Petal length Petal width Species
0 5.1 3.5 1.4 0.2 I. setosa
1 4.9 3.0 1.4 0.2 I. setosa
2 4.7 3.2 1.3 0.2 I. setosa
3 4.6 3.1 1.5 0.2 I. setosa
4 5.0 3.6 1.4 0.2 I. setosa
In [5]:
def scatterplotMatrix(df, scatterColumns, categoricalColumn, colours, title):
    """
    This function create a scatterplot matrix and expects the following inputs:
    - df - The DataFrame which contains the data
    - scatterColumns - a list of the columns in the DataFrame which we want to plot on a scatterplot matrix
    - categoricalColumn - the column which contains the categories of data which should be plotted
    - colours - a list of colours equal in length to the number of categories in the categoricalColumn
    - title - the title of the chart
    
    This function does not create a scatterplot where the same variable intersects with itself.
    """
    
    categories = list(df[categoricalColumn].unique())
    colourLookup = dict(zip(categories, colours))
    
    fig = tools.make_subplots(rows = len(scatterColumns),
                             cols = len(scatterColumns),
                             print_grid = True,
                             shared_xaxes = True,
                             shared_yaxes = True)
    
    diff = max(df[scatterColumns].max()) - min(df[scatterColumns].min())
    
    minimum = min(df[scatterColumns].min()) - (diff * 0.1)
    maximum = max(df[scatterColumns].max()) + (diff * 0.1)
    
    for i, column in enumerate(scatterColumns):
        fig['layout']['xaxis{}'.format(i + 1)].update({'title' : column,
                                                      'range' : [minimum,maximum]})
        
        for j, row in enumerate(scatterColumns):
            fig['layout']['yaxis{}'.format(i + 1)].update({'title' : row,
                                                      'range' : [minimum,maximum]})
            
            if column != row:
                if i == 0 and j == 1:
                    show = True
                else:
                    show = False
                
                for category, colour in colourLookup.items():
                    fig.append_trace({'type' : 'scatter',
                                     'mode' : 'markers',
                                     'x' : df.loc[df[categoricalColumn] == category, column],
                                     'y' : df.loc[df[categoricalColumn] == category, row],
                                     'marker' : {'color' : colour,
                                                'size' : 3},
                                     'name' : category,
                                     'legendgroup' : category,
                                     'showlegend' : show},
                                    col = i + 1,
                                    row = j + 1)
                    
    fig['layout'].update({'title' : title,
                         'height' : len(scatterColumns * 200),
                         'width' : len(scatterColumns * 200)})
    pyo.iplot(fig)
    return fig
In [6]:
irisScatter = scatterplotMatrix(iris, 
                  ['Sepal length','Sepal width','Petal length','Petal width'], 
                  'Species', 
                  ['purple','orange','green'],
                 'Scatterplot matrix of Iris dataset')
/Users/josh/opt/anaconda3/lib/python3.9/site-packages/plotly/tools.py:460: DeprecationWarning:

plotly.tools.make_subplots is deprecated, please use plotly.subplots.make_subplots instead

This is the format of your plot grid:
[ (1,1) x,y     ]  [ (1,2) x2,y2   ]  [ (1,3) x3,y3   ]  [ (1,4) x4,y4   ]
[ (2,1) x5,y5   ]  [ (2,2) x6,y6   ]  [ (2,3) x7,y7   ]  [ (2,4) x8,y8   ]
[ (3,1) x9,y9   ]  [ (3,2) x10,y10 ]  [ (3,3) x11,y11 ]  [ (3,4) x12,y12 ]
[ (4,1) x13,y13 ]  [ (4,2) x14,y14 ]  [ (4,3) x15,y15 ]  [ (4,4) x16,y16 ]

In [7]:
irisScatter = scatterplotMatrix(iris, 
                  ['Petal length','Petal width'], 
                  'Species', 
                  ['purple','orange','green'],
                 'Scatterplot matrix of Iris dataset')
This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]

In [8]:
iris['noCat'] = 'Iris'
In [9]:
irisScatter = scatterplotMatrix(iris, 
                  ['Petal length','Petal width'], 
                  'noCat', 
                  ['purple'],
                 'Scatterplot matrix of Iris dataset')
This is the format of your plot grid:
[ (1,1) x,y   ]  [ (1,2) x2,y2 ]
[ (2,1) x3,y3 ]  [ (2,2) x4,y4 ]

In [ ]: